'''
Description: this code is used for generating
             training/evaluation data for OCR
             and calculating OCR accuracy.
Author: KouXichao
Date: 2020-10-31 21:14:59
LastEditors: KouXichao
LastEditTime: 2020-12-30 14:44:54
FilePath: /LayoutGen/tools/eval_rec_acc.py
Github: https://github.com/kouxichao
'''

import os
import argparse
import numpy as np
import cv2
# import Levenshtein
import math
import distorsion_generator

class SynthLayout():
    '''
    Accessor for a synthetic document-layout dataset stored on disk.

    Files read relative to ``layout_folder``:
        image_names.npy                 array of image file names
        image/<name>                    the document images
        regionInfo/<stem>/bboxes.npy    per-image annotation, indexed as
                                        [word boxes, line boxes, char boxes]
    Files written by ``crop_image_line``:
        line_image/<index>_<i>line.jpg  cropped text-line images
        gt.txt                          "<relative path>\\t<text>" per line

    param layout_folder: root directory of the synthetic dataset.
    param warp: stored on the instance as ``self.warp``; not otherwise
                used by the methods of this class.
    '''
    def __init__(self, layout_folder, warp=False):
        super().__init__()
        self.layout_folder = layout_folder
        self.line_image_dir = os.path.join(layout_folder, 'line_image/')
        self.warp = warp
        self.images = np.load(os.path.join(layout_folder, 'image_names.npy'))
        self.images = self.images.tolist()
        self.all_bboxes = []
        for image_name in self.images:
            # The annotation directory is named after the text before the
            # first '.' of the image file name.
            bbox_path = os.path.join(
                self.layout_folder,
                'regionInfo/' + image_name.split('.')[0] + '/bboxes.npy')
            try:
                bboxes = np.load(bbox_path, allow_pickle=True)
            except Exception:
                # Best effort: an image whose annotation is missing or
                # unreadable is kept, with empty word/line/char boxes.
                # `Exception` (not a bare except) so that Ctrl-C and
                # SystemExit still propagate.
                bboxes = [[], [], []]
            self.all_bboxes.append(bboxes)

    def __len__(self):
        '''
        :return: number of images in the dataset.
        '''
        return len(self.images)

    def get_imagename(self, index):
        '''
        :param index: image index, in range(len(self.images)).
        :return: image file name, including its extension.
        '''
        return self.images[index]

    def load_image_gt(self, index):
        '''
        Load one document image together with its annotations.

        :param index: image index, 0 <= index < len(self.images).
        :return: tuple (image, char_bboxes, word_bboxes, line_bboxes, name)
                 where image is the RGB pixel array and name is the image
                 file name up to its first '.'.
        :raises FileNotFoundError: if the image cannot be read from disk.
        '''
        assert index >= 0 and index < len(self.images)
        img_path = os.path.join(self.layout_folder, 'image/' + self.images[index])
        image = cv2.imread(img_path, cv2.IMREAD_COLOR)
        if image is None:
            # cv2.imread signals failure by returning None instead of
            # raising; fail here with the offending path rather than with
            # an opaque error inside cvtColor.
            raise FileNotFoundError("cannot read image: " + img_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        bboxes = self.all_bboxes[index]
        word_bboxes = bboxes[0]
        line_bboxes = bboxes[1]
        char_bboxes = bboxes[2]

        return image, char_bboxes, word_bboxes, line_bboxes, self.images[index].split(".")[0]

    def crop_image_line(self, index):
        '''
        Crop every text line of one document image and record its text.

        Only axis-aligned (straight) lines are supported.
        TODO: cropping of warped text lines.

        Each crop is written to ``self.line_image_dir`` as
        "<index>_<i>line.jpg" and a "line_image/<file>\\t<text>" record is
        appended to "gt.txt" in ``self.layout_folder``.

        param
            index(int): index of the document image.
        return:
            imagesOfLine(list): cropped line images.
            textOfLine(list): text of each line, stripped of surrounding
                              whitespace.
        '''
        image, _, word_bboxes, line_bboxes, _ = self.load_image_gt(index)
        imagesOfLine = []  # cropped line images
        textOfLine = []    # text of each line

        # len() works for both a loaded array and the empty-list fallback
        # stored by __init__ when an annotation file is missing.
        num_lines = len(line_bboxes)
        if num_lines == 0:
            return imagesOfLine, textOfLine

        # cv2.imwrite does not create missing directories and only returns
        # False on failure, so make sure the target directory exists.
        os.makedirs(self.line_image_dir, exist_ok=True)
        gt_path = os.path.join(self.layout_folder, "gt.txt")
        # Explicit UTF-8 so the label file does not depend on the platform's
        # default encoding.
        with open(gt_path, 'a', encoding='utf-8') as gt_file:
            for i in range(num_lines):
                box = line_bboxes[i]
                # Rows box[1]:box[5], columns box[0]:box[2].
                # NOTE(review): presumably box is a flattened quad
                # (x1, y1, x2, y2, x3, y3, x4, y4) with integer pixel
                # coordinates -- confirm against the generator.
                cropped = image[box[1]:box[5], box[0]:box[2], :]
                # The last element of every word entry is its text; the
                # line text is the concatenation of its words.
                text = "".join(wb[-1] for wb in word_bboxes[i]).strip()

                save_name = str(index) + "_" + str(i) + "line.jpg"
                cv2.imwrite(os.path.join(self.line_image_dir, save_name), cropped)
                gt_file.write('line_image/' + save_name + "\t" + text + "\n")
                imagesOfLine.append(cropped)
                textOfLine.append(text)
        return imagesOfLine, textOfLine

# The code below calls the functions that match whichever recognition algorithm is used.
# Import the recognition module
# import tools.infer.predict_rec as predict_rec
# import tools.infer.utility as utility

if __name__ == '__main__':
    import shutil

    # The dataset lives in <parent of this script's directory>/synthetic_data.
    work_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
    image_dir = os.path.join(work_dir, "synthetic_data")
    # Derive every output path from image_dir so that cleanup targets the
    # same dataset the loader reads, whatever the current working directory.
    line_image_dir = os.path.join(image_dir, 'line_image/')
    gt_path = os.path.join(image_dir, 'gt.txt')

    dataloader = SynthLayout(image_dir)
    nums = len(dataloader)

    # Start from a clean slate: crop_image_line appends to gt.txt, so output
    # left over from a previous run would otherwise be duplicated.
    if os.path.exists(line_image_dir):
        shutil.rmtree(line_image_dir)
    if os.path.exists(gt_path):
        os.remove(gt_path)
    os.makedirs(line_image_dir, exist_ok=True)

    for i in range(nums):
        # data for ocr: writes the line crops and their gt.txt records
        imagesOfLine, textOfLine = dataloader.crop_image_line(i)

#         # recognization
#         textRecognizer = predict_rec.TextRecognizer(args)
#         rec_res, elapse = textRecognizer(imagesOfLine)
#         print("----------image num : {}----------\nline num: {}, rec_res num  : {}, elapse : {}".format(i, len(rec_res), len(imagesOfLine), elapse))
#         # caculate edit distance and accuracy of ocr
#         assert len(imagesOfLine) == len(rec_res)
#         for o,p in zip(textOfLine, rec_res):
#             print("ori text: {}\npred text: {}--------".format(o, p))
#             dist += Levenshtein.distance(o[0], p[0])
#             total_char_num += o[1]
    
#     print("***********Edit Distance: {}, Total Char Num: {}, Accuracy: {}**********".format(dist, total_char_num, 1-dist/total_char_num))
